import spacy
from spacy.matcher import Matcher
from tqdm import tqdm
import os
from utils.util import write_json
import concurrent.futures

class Segmentation:
    """Split protocol texts into short sentences with spaCy.

    Each protocol is sentence-split and whitespace/quote-normalised. It is then
    cut at "and <VERB>" / "and then <VERB>" boundaries. Finally it is stripped of
    comma-delimited clauses that the dependency parse marks as adverbial /
    purpose modifiers. Results are written to JSON by ``segmentation``.
    """

    # Single-character normalisations used by __remove_escape_characters.
    # str.translate applies them in one pass, so the U+2019 -> "'" output is
    # not re-mapped by the "'" -> " " rule (same outcome as the previous
    # chained .replace() order).
    _ESCAPE_TABLE = str.maketrans({
        "\n": " ",
        "\r": " ",
        "\t": " ",
        "\\": " ",
        "\"": " ",
        # NOTE(review): this also breaks contractions/possessives
        # ("cell's" -> "cell s") while U+2019 is kept as "'" - confirm intent.
        "'": " ",
        "\u201c": " ",  # left double quotation mark
        "\u201d": " ",  # right double quotation mark
        "\u00a0": " ",  # no-break space
        "\u2014": "-",  # em dash
        "\u2019": "'",  # right single quotation mark
    })

    def __init__(self, oring_protocol_arr):
        """Store the protocols and load the spaCy pipeline.

        Args:
            oring_protocol_arr: sized iterable of protocol texts. Each item is
                passed to ``self.nlp`` (presumably ``str``; confirm with callers).
        """
        self.oring_protocol_arr = oring_protocol_arr
        self.nlp = spacy.load("en_core_web_trf")

    def segmentation(self, sentences_store_path):
        """Segment every protocol and write the results as JSON.

        Two files are written via ``write_json``:

        * ``sentences_store_path``: one list of sentences per protocol, in
          input order.
        * ``<base>_adverbial.json`` (``base`` is ``sentences_store_path``
          without its extension): one list per protocol of dicts describing
          the adverbial clauses removed by dependency parsing, with keys
          ``"removed phrases"``, ``"sentence"`` and ``"original sentence"``.

        Args:
            sentences_store_path: output path for the sentence JSON.
        """
        base, _ = os.path.splitext(sentences_store_path)
        removed_store_path = f"{base}_adverbial.json"

        protocols = self.oring_protocol_arr
        # NOTE(review): every worker thread shares self.nlp. It is unclear
        # whether concurrent nlp(...) calls are safe or actually faster here.
        # Confirm, or consider nlp.pipe / a process pool.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # executor.map yields in input order, so results align with protocols.
            results = list(tqdm(executor.map(self.__process_protocol, protocols), total=len(protocols)))

        sentences_arr = [sentences for sentences, _ in results]
        removed_arr = [removed for _, removed in results]

        write_json(sentences_store_path, sentences_arr)
        write_json(removed_store_path, removed_arr)

    def __process_protocol(self, protocol):
        """Segment one protocol text.

        Returns:
            (sentences, removed_phrases): all sentences produced by
            ``__deep_segmentation`` for each spaCy sentence longer than 10
            characters, and the concatenated removal mappings.
        """
        sentences = []
        removed_phrases = []
        for sent in self.nlp(protocol).sents:
            # Sentences of 10 characters or fewer are treated as noise.
            if len(sent.text) <= 10:
                continue
            sentence_str = self.__remove_escape_characters(sent.text)
            sentences_tiny, removed_mapping_arr = self.__deep_segmentation(sentence_str)
            sentences.extend(sentences_tiny)
            removed_phrases.extend(removed_mapping_arr)
        return sentences, removed_phrases

    def __remove_escape_characters(self, string):
        """Normalise control characters, quotes and dashes, then collapse spaces.

        Newlines, tabs, backslashes, straight/curly double quotes, straight
        apostrophes and no-break spaces become spaces. An em dash becomes "-"
        and U+2019 becomes "'". Runs of spaces are then collapsed to one space
        and the result is stripped.
        """
        translated = string.translate(self._ESCAPE_TABLE)
        # Collapse *every* run of spaces. A single .replace("  ", " ") leaves
        # two or more spaces behind for longer runs.
        collapsed = " ".join(part for part in translated.split(" ") if part)
        return collapsed.strip()

    @staticmethod
    def __terminate(text):
        """Give *text* a terminal ".", replacing a trailing ",", ";" or ":".

        Text already ending in ".", "!" or "?" is returned unchanged. Text that
        is empty after stripping the trailing clause punctuation returns "".
        """
        text = text.rstrip(",;: ")
        if text and text[-1] not in ".!?":
            text += "."
        return text

    def __deep_segmentation(self, sentence):
        """Split *sentence* at "and <VERB>" / "and then <VERB>" and tidy the pieces.

        The connecting "and" (plus a following "then") is dropped. Each piece
        then goes through ``__parsing_segmentation_2`` to remove adverbial
        clauses. Pieces left empty are discarded, and the rest get terminal
        punctuation.

        Returns:
            (res_sents, removed_mapping_arr), where removed_mapping_arr is as
            returned by ``__parsing_segmentation_2``.
        """
        matcher = Matcher(self.nlp.vocab)
        matcher.add("AND_VERB", [
            [{'LOWER': 'and'}, {'POS': 'VERB'}],
            [{'LOWER': 'and'}, {'LOWER': 'then'}, {'POS': 'VERB'}],
        ])
        doc = self.nlp(sentence)

        pieces = []
        start = 0
        for _, match_start, _ in sorted(matcher(doc), key=lambda match: match[1]):
            if match_start < start:
                # Overlaps the region already consumed; nothing to split.
                continue
            if match_start > start:
                pieces.append(doc[start:match_start].text)
            # Every pattern begins with "and". Skip it, plus a following "then".
            if match_start + 1 < len(doc) and doc[match_start + 1].text.lower() == 'then':
                start = match_start + 2
            else:
                start = match_start + 1
        # Text after the last match.
        if start < len(doc):
            pieces.append(doc[start:].text)

        res_sents, removed_mapping_arr = self.__parsing_segmentation_2(pieces)

        # Give each piece terminal punctuation, and drop pieces that ended up
        # empty (all of their clauses were removed).
        res_sents = [self.__terminate(sent) for sent in res_sents]
        res_sents = [sent for sent in res_sents if sent]
        return res_sents, removed_mapping_arr

    def __parsing_segmentation_1(self, sentence_arr):
        """Alternative comma-clause filter (not called elsewhere in this class).

        For each sentence, keep the ","-split pieces that contain the ROOT
        token's text, or that contain none of the texts of tokens whose head is
        the ROOT (punctuation excluded). The kept pieces are re-joined with ",".
        """
        res_sents = []
        for sentence in sentence_arr:
            doc_sent = self.nlp(sentence)
            root = None  # stays None when the parse has no ROOT (e.g. empty input)
            dep_tokens = []
            for token in doc_sent:
                if token.dep_ == "ROOT":
                    root = token.text
                if token.head.dep_ == "ROOT" and token.dep_ != "punct":
                    dep_tokens.append(token.text)
            kept = [
                piece for piece in sentence.split(",")
                if (root is not None and root in piece)
                or not any(word in piece for word in dep_tokens)
            ]
            res_sents.append(','.join(kept).strip())
        return res_sents

    @staticmethod
    def __split_at_commas(doc):
        """Yield non-empty token lists, cutting after every "," token (kept in its chunk)."""
        chunk = []
        for token in doc:
            chunk.append(token)
            if token.text == ",":
                yield chunk
                chunk = []
        if chunk:
            yield chunk

    @staticmethod
    def __is_decoration(token):
        """Return True if *token* marks an adverbial / purpose clause.

        Matches a preposition (ADP/prep) attached to the ROOT ("In order to...",
        "For ..."), an adverb (ADV/advmod) attached to the ROOT ("Finally, ..."),
        or a particle auxiliary (PART/aux) such as infinitival "to".
        """
        return (
            (token.pos_ == "ADP" and token.dep_ == "prep" and token.head.dep_ == "ROOT")
            or (token.pos_ == "ADV" and token.dep_ == "advmod" and token.head.dep_ == "ROOT")
            or (token.pos_ == "PART" and token.dep_ == "aux")
        )

    def __parsing_segmentation_2(self, sentence_arr):
        """Remove comma-delimited clauses that look like adverbial modifiers.

        Each sentence is parsed and cut into chunks after every "," token. A
        chunk is kept when it contains the dependency ROOT or contains no
        decoration token (see ``__is_decoration``). Otherwise it is removed.

        Returns:
            (res_sents, removed_mapping_arr):
            res_sents: one string per input sentence. The kept chunks are
                concatenated with their original spacing (``text_with_ws``)
                and the result is stripped.
            removed_mapping_arr: one dict for each input sentence that lost at
                least one chunk, with keys "removed phrases" (list of removed
                chunk texts), "sentence" (the result) and "original sentence".
        """
        res_sents = []
        removed_mapping_arr = []

        for sentence in sentence_arr:
            doc_sent = self.nlp(sentence)
            kept = []
            removed = []
            for chunk in self.__split_at_commas(doc_sent):
                # text_with_ws rebuilds the exact original substring, so no
                # spaces appear before punctuation and clauses are not glued
                # together.
                chunk_text = "".join(token.text_with_ws for token in chunk)
                contains_root = any(token.dep_ == "ROOT" for token in chunk)
                if contains_root or not any(self.__is_decoration(token) for token in chunk):
                    kept.append(chunk_text)
                else:
                    removed.append(chunk_text.strip())

            result = "".join(kept).strip()
            res_sents.append(result)
            # Record each removal exactly once, for the sentence it came from.
            if removed:
                removed_mapping_arr.append({"removed phrases": removed, "sentence": result, "original sentence": sentence})

        return res_sents, removed_mapping_arr